#include "rsa.h"

#include <string.h>
#include <assert.h>

#define ACCURACY 5
#define SINGLE_MAX 10000
#define EXPONENT_MAX 1000
#define BUF_SIZE 1024

// Hard-coded RSA key pair. In this file the public half is used by
// encodeTimestamp and the private half by decodeTimestamp; both halves share
// the same modulus.
//
// m_bytes is the number of plaintext bytes packed into one RSA block.
// Encryption was observed to go wrong when .m_bytes > 1 (the original note
// suspected overflow). NOTE(review): the more likely cause is that
// encodeMessage packs each byte into only 7 bits (weight 128^j), so binary
// bytes with the high bit set cannot all be recovered once a block holds more
// than one byte.
// NOTE(review): the modulus 12225293 is below 2^24 and can be factored
// instantly, and the private exponent is stored in source; this key provides
// obfuscation, not confidentiality.
struct Key key ={
    .m_public = {
        .m_exponent = 823,
        .m_modulus = 12225293, 
        .m_bytes = 1,   // tested: encrypting more than one byte per block produced wrong results
    },
    .m_private = {
        .m_exponent = 4468123,
        .m_modulus = 12225293,
        .m_bytes = 1,   // must match .m_public.m_bytes (same block packing on both sides)
    },
};
/**
 * Computes a^b mod c by right-to-left square-and-multiply.
 *
 * The base is reduced modulo c before the loop, so every product below is of
 * two values with magnitude < |c| and fits in a long long. Without this, a
 * large base made `a * a` overflow (undefined behaviour).
 *
 * C's % truncates toward zero, so a negative base gives a result in (-c, 0]
 * (e.g. modpow(-2, 3, 5) == -3). This is deliberately kept: this file passes
 * (possibly negative) char values as the base when encrypting, and the sign
 * must survive the round trip.
 *
 * A non-positive exponent is treated as 0; a^0 mod c is 1 % c (0 when c == 1).
 */
int modpow(long long a, long long b, int c) {
	long long res = 1 % c;
	a %= c; /* keeps |a| < |c| so a * a and res * a cannot overflow */
	while(b > 0) {
		if(b & 1) {
			res = (res * a) % c;
		}
		b >>= 1;
		a = (a * a) % c;
	}
	return (int)res;
}

/**
 * Jacobi symbol (a/n), intended for odd n > 0.
 *
 * Returns 1 or -1 when a and n are coprime and 0 otherwise. The loop strips
 * factors of two from a (an odd number of halvings flips the sign unless
 * n % 8 is 1 or 7), then applies reciprocity by swapping a and n (flipping the
 * sign when neither is 1 mod 4). A non-positive a is not reduced and yields 0.
 */
int jacobi(int a, int n) {
	int sign = 1;

	while (a > 1 && a != n) {
		a %= n;
		if (a <= 1 || a == n) {
			break;
		}

		int halvings = 0;
		for (; a % 2 == 0; a /= 2) {
			halvings++;
		}
		if (halvings % 2 == 1) {
			int r8 = n % 8;
			if (r8 != 1 && r8 != 7) {
				sign = -sign;
			}
		}
		if (a <= 1 || a == n) {
			break;
		}

		if (a % 4 != 1 && n % 4 != 1) {
			sign = -sign; /* reciprocity: both are 3 mod 4 for odd a, n */
		}
		int swapped = a;
		a = n;
		n = swapped;
	}

	/* a == 0 or a == n means gcd(a, n) != 1 */
	return (a == 1) ? sign : 0;
}

/**
 * One Solovay-Strassen round for base a.
 *
 * Returns 1 when (a/n) is non-zero and a^((n-1)/2) mod n equals it, with -1
 * represented as n - 1. Returns 0 when a proves n composite.
 */
int solovayPrime(int a, int n) {
	int expected = jacobi(a, n);
	if (expected == -1) {
		expected = n - 1; /* -1 written as a residue mod n */
	}
	if (expected == 0) {
		return 0; /* (a/n) == 0: a shares a factor with n */
	}
	return modpow(a, (n - 1) / 2, n) == expected;
}

/**
 * Test if n is probably prime using k rounds of the Solovay-Strassen test.
 *
 * @param n  candidate; values below 2 are never prime
 * @param k  number of random rounds (k <= 0 accepts any odd n > 2)
 * @return 1 if n is 2 or an odd number that passed all k rounds, 0 if n is
 *         certainly not prime
 *
 * Bases are drawn with rand(), so results depend on the caller's seeding.
 * Bases are limited by RAND_MAX.
 */
int probablePrime(int n, int k) {
	if(n < 2) return 0; /* also keeps n - 2 below from overflowing for very negative n */
	if(n == 2) return 1;
	if(n % 2 == 0) return 0;
	while(k-- > 0) {
		/* random base in [2, n - 1] */
		if(!solovayPrime(rand() % (n - 2) + 2, n)) return 0;
	}
	return 1;
}

/**
 * Pick a random odd (probable) prime.
 *
 * Starts at a random point below n, makes it odd, and scans upward two at a
 * time. The scan wraps modulo n rounded up to even, so the candidates stay odd.
 * The returned value lies in [3, n]. It can equal n itself when n is odd and
 * prime, and is never 2.
 *
 * The first prime at or after the random start is chosen, so primes following
 * a large gap are favoured; the distribution is nowhere near uniform.
 *
 * @param n  upper bound; must be > 2. For n <= 2 no odd prime is reachable and
 *           the scan would never terminate; n <= 0 would also divide by zero.
 */
int randPrime(int n) {
	assert(n > 2);
	int prime = rand() % n;
	n += n % 2; /* n needs to be even so modulo wrapping preserves oddness */
	prime += 1 - prime % 2;
	while(1) {
		if(probablePrime(prime, ACCURACY)) return prime;
		prime = (prime + 2) % n;
	}
}

/**
 * Greatest common divisor of a and b by Euclid's algorithm.
 *
 * gcd(x, 0) == x. Because C's % truncates toward zero, the result can be
 * negative when an input is negative (e.g. gcd(0, -4) == -4).
 */
int gcd(int a, int b) {
	return (b == 0) ? a : gcd(b, a % b);
}

/**
 * Pick a random exponent e in [3, n - 1] with gcd(e, phi) == 1.
 *
 * Starts at a random point and scans upward, wrapping back to 3. Like
 * randPrime, the distribution is nowhere near uniform.
 *
 * The starting value is clamped to at least 3 as well. Previously, a start of
 * 1 was returned immediately because gcd(1, phi) == 1, giving e = 1, which
 * leaves every message unencrypted. Similarly, 2 or 0 could be returned for
 * some phi.
 *
 * @param phi  value e must be coprime to (presumably the totient during key
 *             generation - confirm at call sites)
 * @param n    exclusive upper bound; must be > 3. The scan never terminates if
 *             no e in [3, n - 1] is coprime to phi.
 */
int randExponent(int phi, int n) {
	assert(n > 3);
	int e = rand() % n;
	if(e <= 2) e = 3;
	while(1) {
		if(gcd(e, phi) == 1) return e;
		e = (e + 1) % n;
		if(e <= 2) e = 3;
	}
}

/**
 * Compute n^-1 mod modulus with the extended Euclidean algorithm.
 *
 * n is first reduced into [0, modulus), so negative and oversized values are
 * handled correctly. Only the Bezout coefficient of n is tracked; the one for
 * modulus is never needed.
 *
 * @param n        value to invert
 * @param modulus  modulus, must be positive
 * @return the unique x in [0, modulus) with n * x = 1 (mod modulus), or 0 when
 *         no inverse exists (gcd(n, modulus) != 1) or modulus <= 0. For
 *         modulus > 1, 0 is never a genuine inverse, so callers can test for it.
 */
int inverse(int n, int modulus) {
	int r0, r1, s0, s1, q, t;

	if(modulus <= 0) return 0;
	r0 = n % modulus;
	if(r0 < 0) r0 += modulus;
	r1 = modulus;
	s0 = 1; /* r0 = s0 * n (mod modulus) */
	s1 = 0; /* r1 = s1 * n (mod modulus) */
	while(r1 != 0) {
		q = r0 / r1;
		t = r0 - q * r1; r0 = r1; r1 = t;
		t = s0 - q * s1; s0 = s1; s1 = t;
	}
	if(r0 != 1) return 0; /* r0 is gcd(n, modulus): not invertible */
	s0 %= modulus;
	if(s0 < 0) s0 += modulus;
	return s0;
}

/**
 * Read stream fd until EOF into a newly allocated buffer, then pad it with
 * '\0' bytes up to a multiple of `bytes` so it can be encrypted block-wise.
 *
 * At least one '\0' is always appended, to mark the end of the message. When
 * the data length is already a multiple of `bytes`, a whole extra zero block
 * is added.
 *
 * @param fd      open stream to read
 * @param buffer  out: receives the malloc'd buffer (caller frees); set to NULL
 *                on failure
 * @param bytes   block size in bytes, must be > 0
 * @return the padded length of *buffer (a multiple of bytes, not the number of
 *         bytes read), or -1 if memory allocation fails
 */
int readFile(FILE* fd, char** buffer, int bytes) {
	int len = 0, cap = BUF_SIZE, padded;
	size_t r;
	char chunk[BUF_SIZE];
	char *data, *grown;

	assert(bytes > 0);
	*buffer = NULL;
	data = malloc(cap);
	if(data == NULL) return -1;

	while((r = fread(chunk, 1, sizeof chunk, fd)) > 0) {
		if(len + (int)r >= cap) {
			cap *= 2; /* r <= BUF_SIZE <= cap, so a single doubling always suffices */
			grown = realloc(data, cap);
			if(grown == NULL) {
				free(data); /* realloc failure leaves the old block allocated */
				return -1;
			}
			data = grown;
		}
		memcpy(data + len, chunk, r);
		len += (int)r;
	}

	/* Pad to the next multiple of bytes, strictly beyond len */
	padded = len + bytes - len % bytes;
	if(padded > cap) {
		grown = realloc(data, padded);
		if(grown == NULL) {
			free(data);
			return -1;
		}
		data = grown;
	}
	memset(data + len, '\0', (size_t)(padded - len));

	*buffer = data;
	return padded;
}

/**
 * RSA encryption primitive: returns the ciphertext c = m^e mod n, computed by
 * modpow.
 */
int encode(int m, int e, int n) {
	const int ciphertext = modpow((long long)m, (long long)e, n);
	return ciphertext;
}

/**
 * RSA decryption primitive: returns the plaintext m = c^d mod n, computed by
 * modpow with the private exponent d and the shared modulus n.
 */
int decode(int c, int d, int n) {
	const int plaintext = modpow((long long)c, (long long)d, n);
	return plaintext;
}


/**
 * @brief  Encrypt a byte string with RSA, `bytes` bytes per block.
 * @note   len must be a multiple of bytes (readFile pads its output this way);
 *         otherwise the last block would read past the end of message.
 *         Each block is packed into one integer, 7 bits per byte, least
 *         significant byte first:
 *             m = message[i] + message[i+1]*128 + ... + message[i+bytes-1]*128^(bytes-1)
 *             encoded[i/bytes] = m^exponent mod modulus
 *         Only byte values 0..127 pack without overlapping a neighbour. With
 *         bytes == 1, the whole char value (negative when char is signed) is
 *         encrypted as-is. m must also be smaller than modulus for decryption
 *         to recover it, and bytes > 4 can overflow int.
 * @param  len:      message length in bytes
 * @param  bytes:    block size; e.g. bytes=3 groups every three bytes into one integer
 * @param  message:  bytes to encrypt
 * @param  exponent: public exponent
 * @param  modulus:  public modulus; together with exponent forms the public key
 * @retval malloc'd array of len/bytes encrypted blocks (caller frees), or NULL
 *         if allocation fails
 */
int* encodeMessage(int len, int bytes, char* message, int exponent, int modulus) {
	int *encoded;
	int x, i, j;

	assert(bytes > 0 && len % bytes == 0);
	encoded = malloc((len/bytes) * sizeof(int));
	if(encoded == NULL) return NULL;
	for(i = 0; i < len; i += bytes) {
		x = 0;
		for(j = 0; j < bytes; j++) x += message[i + j] * (1 << (7 * j));
		encoded[i/bytes] = encode(x, exponent, modulus);
	}
	return encoded;
}

/**
 * @brief  Decrypt an array of RSA blocks produced by encodeMessage.
 * @note   Each block is decrypted to an integer x and split into `bytes`
 *         groups, least significant first:
 *             out[i*bytes + j] = (x >> (7*j)) % 128    for j < bytes - 1
 *             out[i*bytes + j] = (char)(x >> (7*j))    for the top group
 *         The top group is kept whole so that a lone byte >= 0x80 survives.
 *         With bytes == 1, encodeMessage encrypts 0x80 as -128 (signed char)
 *         or 128 (unsigned char), and `% 128` used to turn it into 0.
 *         The result is NOT NUL-terminated.
 * @param  len:         number of blocks in cryptogram
 * @param  bytes:       plaintext bytes per block (must match the value used to encode)
 * @param  cryptogram:  encrypted blocks
 * @param  exponent:    private exponent
 * @param  modulus:     modulus; together with exponent forms the private key
 * @retval malloc'd buffer of len * bytes chars (caller frees), or NULL if
 *         allocation fails
 */
char* decodeMessage(int len, int bytes, int* cryptogram, int exponent, int modulus) {
	char *decoded_string = malloc((size_t)len * bytes);
	int x, i, j;

	if(decoded_string == NULL) return NULL;
	for(i = 0; i < len; i++) {
		x = decode(cryptogram[i], exponent, modulus);
		for(j = 0; j < bytes; j++) {
			int group = x >> (7 * j);
			decoded_string[i*bytes + j] = (j == bytes - 1) ? (char)group : (char)(group % 128);
		}
	}
	return decoded_string;
}

/**
 * @brief  Decrypt a packet in the layout written by encodeTimestamp and
 *         extract the timestamp and position.
 * @note   Uses the private key in `key`. The decrypted plaintext is the 8
 *         bytes of the timestamp followed by the 4 bytes of the position, copied
 *         with memcpy (host byte order). Further decrypted bytes are ignored.
 *         The packet must decrypt to at least 12 bytes, i.e. hold at least
 *         12 / m_bytes ints (48 bytes with m_bytes == 1). The old `< 16` check
 *         let shorter packets through and read past the decoded buffer.
 * @param  buf:          encrypted packet of int blocks; need not be int-aligned
 * @param  packet_len:   size of buf in bytes
 * @param  m_position:   out: decrypted position
 * @param  m_timestamp:  out: decrypted timestamp
 * @retval 0 on success; -1 if the packet is too short or memory allocation
 *         fails, in which case the outputs are not modified
 */
int decodeTimestamp(void *buf,u_int32_t packet_len,
        u_int32_t *m_position,int64_t *m_timestamp)
{
    int bytes = key.m_private.m_bytes;
    size_t blocks = packet_len / sizeof(int);
    size_t plain_len = sizeof(*m_timestamp) + sizeof(*m_position);

    if(bytes <= 0 || blocks * (size_t)bytes < plain_len)
        return -1;

    /* Copy into a properly aligned int array: buf may point into a raw packet */
    int *encoded = malloc(blocks * sizeof(int));
    if(encoded == NULL)
        return -1;
    memcpy(encoded, buf, blocks * sizeof(int));

    char *decoded = decodeMessage((int)blocks, bytes,
            encoded, key.m_private.m_exponent,
            key.m_private.m_modulus);
    free(encoded);
    if(decoded == NULL)
        return -1;

    memcpy(m_timestamp, decoded, sizeof(*m_timestamp));
    memcpy(m_position, decoded + sizeof(*m_timestamp), sizeof(*m_position));

    free(decoded);
    return 0;
}
/**
 * @brief  Encrypt the timestamp and position with the public key in `key` and
 *         store the encrypted blocks in buf.
 * @note   Plaintext layout: 8 bytes of m_timestamp followed by 4 bytes of
 *         m_position, copied with memcpy (host byte order). 12 must be
 *         divisible by key.m_public.m_bytes.
 * @param  buf:          caller-allocated destination (heap or stack) of at least
 *                       12 / key.m_public.m_bytes * sizeof(int) bytes
 *                       (48 with m_bytes == 1)
 * @param  m_position:   position to encrypt
 * @param  m_timestamp:  timestamp to encrypt
 * @retval number of bytes written to buf, or -1 if memory allocation fails
 */
int encodeTimestamp(void *buf,u_int32_t m_position,int64_t m_timestamp)
{
    assert(buf!=NULL);
    assert(key.m_public.m_bytes > 0 && 12 % key.m_public.m_bytes == 0);

    char encoded_string[12];
    memcpy(encoded_string,&m_timestamp,8);
    memcpy(encoded_string+8,&m_position,4);
    int *encoded = encodeMessage(12,key.m_public.m_bytes,encoded_string,
            key.m_public.m_exponent,key.m_public.m_modulus);
    if(encoded == NULL)
        return -1;

    int encoded_len = 12/key.m_public.m_bytes*(int)sizeof(int);
    memcpy(buf,encoded,encoded_len);
    free(encoded);   /* encodeMessage returns a malloc'd array; it used to leak on every call */

    return encoded_len;
}